{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "65ecaf02",
   "metadata": {},
   "source": [
    "# [5、深入剖析PyTorch DataLoader源码](https://www.bilibili.com/video/BV1kq4y1G75V?spm_id_from=333.788.player.switch&vd_source=cdd897fffb54b70b076681c3c4e4d45d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fbd903d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision as tv\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "359ae690",
   "metadata": {},
   "source": [
    "## Dataloader\n",
    "\n",
    "loader用于把单个样本组织成一个个批次，用于模型训练\n",
    "\n",
    "文档：\n",
    "`DataLoader` 结合了数据集和采样器，并提供对给定数据集的可迭代访问。\n",
    "\n",
    "`~torch.utils.data.DataLoader` 支持 map-style 和 iterable-style 的数据集，可进行单进程或多进程加载，自定义加载顺序以及可选的自动批处理（collation）和内存固定（memory pinning）。\n",
    "\n",
    "更多详情请参见 :py:mod:`torch.utils.data` 文档页面。\n",
    "\n",
    "**参数:**\n",
    "\n",
    "* **dataset** (`Dataset`)：从中加载数据的数据集。\n",
    "* **batch_size** (`int`, 可选)：每批加载多少个样本（默认值：`1`）。\n",
    "* **shuffle** (`bool`, 可选)：设置为 `True` 时，会在每个 epoch 重新打乱数据（默认值：`False`）。\n",
    "* **sampler** (`Sampler` or `Iterable`, 可选)：定义从数据集中抽取样本的策略。可以是任何实现了 `__len__` 的 `Iterable` 对象。如果指定了此参数，则 `shuffle` 不能被指定。\n",
    "* **batch_sampler** (`Sampler` or `Iterable`, 可选)：与 `sampler` 类似，但一次返回一批索引。此参数与 `batch_size`、`shuffle`、`sampler` 和 `drop_last` 互斥。\n",
    "* **num_workers** (`int`, 可选)：用于数据加载的子进程数量。`0` 表示在主进程中加载数据。（默认值：`0`）\n",
    "* **collate_fn** (`Callable`, 可选)：将样本列表合并成一个 mini-batch 的张量（Tensor）。当从 map-style 数据集进行批处理加载时使用。\n",
    "* **pin_memory** (`bool`, 可选)：如果为 `True`，数据加载器在返回张量之前会将其复制到设备/CUDA 的固定内存中。如果您的数据元素是自定义类型，或者您的 `collate_fn` 返回的是一个自定义类型的批次，请参阅下面的示例。\n",
    "* **drop_last** (`bool`, 可选)：如果数据集大小不能被批大小整除，则设置为 `True` 以删除最后一个不完整的批次。如果为 `False` 并且数据集大小不能被批大小整除，则最后一个批次将较小。（默认值：`False`）\n",
    "* **timeout** (`numeric`, 可选)：如果为正数，则表示从工作进程（worker）收集一批数据的超时值。应始终为非负数。（默认值：`0`）\n",
    "* **worker_init_fn** (`Callable`, 可选)：如果不是 `None`，这个函数将在每个工作子进程中被调用，以工作进程 id（一个在 `[0, num_workers - 1]` 范围内的整数）作为输入，在设定种子之后和数据加载之前执行。（默认值：`None`）\n",
    "* **multiprocessing_context** (`str` or `multiprocessing.context.BaseContext`, 可选)：如果为 `None`，将使用操作系统的默认 `multiprocessing context`_。（默认值：`None`）\n",
    "* **generator** (`torch.Generator`, 可选)：如果不是 `None`，这个随机数生成器（RNG）将被 RandomSampler 用于生成随机索引，并被 multiprocessing 用于为工作进程生成 `base_seed`。（默认值：`None`）\n",
    "* **prefetch_factor** (`int`, 可选, 仅限关键字参数)：每个工作进程提前加载的批次数。`2` 表示所有工作进程总共将预取 `2 * num_workers` 个批次。（默认值取决于 `num_workers` 的设置。如果 `num_workers=0`，默认值为 `None`。否则，如果 `num_workers > 0`，默认值为 `2`）。\n",
    "* **persistent_workers** (`bool`, 可选)：如果为 `True`，数据加载器在数据集被消耗一次后不会关闭工作进程。这允许保持工作进程中的 `Dataset` 实例处于活动状态。（默认值：`False`）\n",
    "* **pin_memory_device** (`str`, 可选)：当 `pin_memory` 为 `True` 时，指定要固定内存的设备。如果未指定，将默认为当前的 `accelerator<accelerators>`。不推荐使用此参数，并计划在未来弃用。\n",
    "* **in_order** (`bool`, 可选)：如果为 `False`，数据加载器将不强制批次按先进先出的顺序返回。仅在 `num_workers > 0` 时适用。（默认值：`True`）\n",
    "\n",
    "\n",
    "如果使用 `spawn` 启动方法，`worker_init_fn` 不能是不可 pickle 的对象，例如 lambda 函数。有关 PyTorch 中多进程处理的更多详细信息，请参阅 `multiprocessing-best-practices`。\n",
    "\n",
    "`len(dataloader)` 的长度推算是基于所使用的采样器长度。当 `dataset` 是一个 `~torch.utils.data.IterableDataset` 时，它会返回一个基于 `len(dataset) / batch_size` 的估计值，并根据 `drop_last` 进行适当的四舍五入，而不管多进程加载的配置如何。这代表了 PyTorch 能做出的最佳猜测，因为 PyTorch 相信用户的 `dataset` 代码能正确处理多进程加载以避免数据重复。\n",
    "\n",
    "> 然而，如果数据分片（sharding）导致多个工作进程有不完整的最后一个批次，这个估计值仍然可能不准确，因为 (1) 一个原本完整的批次可能被分成多个批次，以及 (2) 当设置 `drop_last` 为 True 时，可能会丢弃超过一个批次大小的样本。不幸的是，PyTorch 通常无法检测到这种情况。\n",
    "\n",
    "> 有关这两种数据集类型以及 `~torch.utils.data.IterableDataset` 如何与 `多进程数据加载`_ 交互的更多详细信息，请参阅 `数据集类型`_。\n",
    "\n",
    "有关随机种子相关的问题，请参阅 `reproducibility`、`dataloader-workers-random-seed` 和 `data-loading-randomness` 的说明。\n",
    "\n",
    "将 `in_order` 设置为 `False` 可能会损害可复现性，并且在数据不平衡的情况下，可能导致馈送给训练器的数据分布出现偏差。\n",
    "\n",
    "_multiprocessing context:\n",
    "    [https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c78dae7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_data = datasets.FashionMNIST(\n",
    "    root=\"mnist_dataset\", train=True, download=True, transform=ToTensor()\n",
    ")\n",
    "\n",
    "test_data = datasets.FashionMNIST(\n",
    "    root=\"mnist_dataset\", train=False, download=True, transform=ToTensor()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "988f0a9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)\n",
    "test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63a4e02e",
   "metadata": {},
   "source": [
    "可以自定义 sample 用于采样，但是和 shuffle 冲突\n",
    "```py\n",
    "if sampler is not None and shuffle:\n",
    "    raise ValueError(\"sampler option is mutually exclusive with shuffle\")\n",
    "```\n",
    "\n",
    "默认 sampler 是 `RandomSampler` 和 `SequentialSampler`\n",
    "```py\n",
    "if sampler is None:  # give default samplers\n",
    "    if self._dataset_kind == _DatasetKind.Iterable:\n",
    "        # See NOTE [ Custom Samplers and IterableDataset ]\n",
    "        sampler = _InfiniteConstantSampler()\n",
    "    else:  # map-style\n",
    "        if shuffle:\n",
    "            sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]\n",
    "        else:\n",
    "            sampler = SequentialSampler(dataset)  # type: ignore[arg-type]\n",
    "```\n",
    "\n",
    "\n",
    "关于 `batch_sampler`，其实就是一个个地取数据，满了就返回一个batch的index：\n",
    "\n",
    "```py\n",
    "def __iter__(self) -> Iterator[list[int]]:\n",
    "    # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951\n",
    "    sampler_iter = iter(self.sampler)\n",
    "    if self.drop_last:\n",
    "        # Create multiple references to the same iterator\n",
    "        args = [sampler_iter] * self.batch_size\n",
    "        for batch_droplast in zip(*args):\n",
    "            yield [*batch_droplast]\n",
    "    else:\n",
    "        batch = [*itertools.islice(sampler_iter, self.batch_size)]\n",
    "        while batch:\n",
    "            yield batch\n",
    "            batch = [*itertools.islice(sampler_iter, self.batch_size)]\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2394b80b",
   "metadata": {},
   "source": [
    "## collate_fn\n",
    "\n",
    "```py\n",
    "if collate_fn is None:\n",
    "    if self._auto_collation:\n",
    "        collate_fn = _utils.collate.default_collate\n",
    "    else:\n",
    "        collate_fn = _utils.collate.default_convert\n",
    "```\n",
    "\n",
    "默认使用`default_collate`，只做了点数据转换，其他的什么事也不干TAT"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
